久别重逢话双塔
回归前言
大约半年前,换了新工作,从此开始了为了丰富海外人民的业余文化生活,而殚精竭虑、绞尽脑汁的日子,也就没时间、没心情打理我的知乎和微信公众号了,两个地方都荒得长草了。那今天怎么又回来了?原因有三:
的确是舍不得我这辛辛苦苦攒下的粉。每个关注我的朋友,都是靠我爆肝码字、做视频换来的,是对我过去辛苦创作的一种肯定,实在舍不得半途而废。何况还有同学私信我追更,使我深受感动。 知乎最近应该搞了一个“返航计划”,以唤醒流失作者,算是给我的回归增加了一个契机。 促使我下定决心回归的最后一根稻草是,唉,最近比较烦、比较烦、比较烦。人家梁朝伟心烦的时候,能去伦敦喂鸽子,而我只能写篇文章,娱众并自娱,在虚拟世界中慰藉心灵。
书归正传,今天咱们说说双塔模型。可能是因为我的网名有一个“塔”字,换到新东家后,就让我负责双塔召回+粗排上的工作,从此开始了一段与“双塔”爱恨交加、一言难尽的日子。有感而发,有了下面的文章。
正文开始之前,先声明两点:
双塔是“召回”+“粗排”的绝对主力模型。但是要让双塔在召回、粗排中发挥作用,带来收益,只改进双塔结构是远远不够的。如何采样以减少“样本选择偏差”、如何保证上下游目标一致性、如何在双塔中实现多任务间的信息转移...,都是非常重要的课题。但是受篇幅限制,本文只聚集于双塔模型结构上的改进。 市面上关于双塔改进的文章有很多,本文不会一一罗列这些改进的细节。遵循本人文章的一贯风格,本文将为读者梳理这些改进背后的发展脉络,了解这些改进“为什么这样做”,希望能够激发读者在“改进双塔”上新的灵感。至于“怎么做”,请感兴趣的读者稳步原文。
双塔分离:成也萧何,败也萧何
双塔的模型结构很简单,
训练的时候 将用户侧的信息喂入一个DNN(aka, user tower),最终得到一个user embedding 将物料侧的信息喂入一个DNN(aka, item tower),最终得到一个item embedding 拿user embedding与item embedding,做点积或cosine,得到logit,代表user & item之间的匹配程度 设计loss,将user tower, item tower, 和各种特征的embedding,都训练出来 注意,虽然训练流程类似,但是“双塔召回”与“双塔粗排”所需要的负样本,截然不同,详情见《负样本为王:评Facebook的向量化召回算法》。 双塔用作召回,线上预测时 离线、周期性、批量将item信息(i.e., 十万、百万级别)喂入item tower,得到item embedding 得到的item embedding,导入FAISS,建立索引 线上接到用户请求后,将user信息喂入user tower,得到user embedding 拿得到的user embedding,在faiss中做近邻搜索(ANN),得到与user embedding相邻item,作为召回内容返回 双塔用作粗排,线上预测时 与双塔召回时一样,item embedding依然是离线、周期性、批量生成。 由于一个user request中的候选item set,已经由召回阶段缩小到可接受的范围,所以离线生成的item embedding无需灌入faiss建立索引了,只需要以KV方式存储起来 线上接到粗排请求后,将user信息喂入user tower,得到user embedding 拿粗排请求中的candidate item id(千级别)去KV库中检索出对应的item embedding 拿user embedding与检索出来的item embedding,逐一做点积或cosine得到user & item的匹配度 将candidate item按与user匹配度降序排列,top N candidate item喂入下游精排。
综上所述,我们发现双塔模型最大的特点就是“双塔分离”。只不过“部署时分离”成了双塔最大的优点,而“训练时分离”成了制约双塔性能的最大因素。
部署时分离 由于item tower完全不依赖于user信息,所以海量的item embedding可以周期性、批量、离线生成,大大减轻了线上serving的压力 由于user tower完全不依赖于item信息,所以无论候选集是几千(粗排)或十万级、百万级(召回),user embedding只需要生成一遍 反观精排模型,由于从最底层user & item信息就需要开始产生交叉,“难舍难分”。所以user信息必须与每一条candidate item过一遍精排模型,从而限制了精排候选集的规模 训练时分离 user信息只能喂入user tower, item信息只能喂入item tower,没有地方喂入user & item之间的交叉特征 user侧信息与item侧信息,只有唯一一次交叉机会,就是在双塔生成各自的embedding之后的那次点积或cosine。但是这时参与交叉的user/item embedding,已经是高度浓缩的了。一些细节信息已经损失,永远失去了与对侧信息交叉的机会。 为了线上快速serving,交叉只能是简单的dot或cosine。一些复杂的、依赖于底层信息的交叉结构,比如target item对user action history的attention,也在双塔中找不到位置。
综上所述,“双塔分离”的结构,既是保障线上快速serving的优点,也是不能使用交叉特征与结构、导致两侧信息交叉过晚、制约模型表达能力的最大缺点。“线上快速serving”正好对召回、粗排这种“大候选集”场景的胃口,而由于后面还有能力强大的精排,所以“模型表达能力弱”的缺点,也能够为召回、粗排所容忍。因此,双塔模型成为召回+粗排的主流模型,几乎是粗排的不二选择。
但是,你也知道,互联网人追逐OKR的脚步是永无停歇的。在“做大做强,再创辉煌”的口号激(nei)励(juan)下,说好的可容忍的“表达能力弱”变得面目可憎,卷起来的互联网人从来不做选择,“线上快速serving的能力”与“强悍的模型表达能力”,我都要!!!
双塔改建计划
综上所以,双塔最大的缺点就在于,user&item两侧信息交叉得太晚,等到最终能够通过dot或cosine交叉的时候,user & item embedding已经是高度浓缩的了,一些细粒度的信息已经在塔中被损耗掉,永远失去了与对侧信息交叉的机会。所以,双塔改建最重要的一条主线就是:如何保留更多的信息在tower的final embedding中,从而有机会和对侧塔得到的embedding交叉?围绕着这条主线,勤劳的互联网打工人设计出很多的改进方案。
双塔重地,闲人免进
这种思路以张俊林大佬的SENet为代表。既然把信息“鱼龙混杂”一古脑地喂入塔,其中的噪声造成污染,导致很多细粒度的重要信息未能“幸存”到final dot product那一刻。SENet的思路就是,在将信息喂入塔之前,插入SEBlock。SEBlock动态学习各特征的重要性,增强重要信息,弱化甚至过滤掉原始特征中的噪声,从而减少信息在塔中传播过程中的污染与损耗,能够让可能多的重要信息“撑”到final dot product那一刻。
重要信息,走捷径,一步登顶
信息在塔中向上流动的过程,也是一个信息压缩的过程,不可避免地带来信息损耗。所以,我们很自然地想到,何不让那些重要信息抄近路,走捷径,把它们直接送到离final dot product更近的地方。
提到抄近路,大家自然而然地想到ResNet,如下图所示。
是喂入塔的原始信息,经过塔中的信息流动,到最后一层时已经损失了很多重要的、细粒度信息。 这时,我们将抄近路,送到最后一层与tower的输出融合(图中是element-wise add,但是显然那并不是唯一的融合方式),得到final embedding 这时的既包含了经过tower高度浓缩后的信息,又包含原始输入中的一些细粒度信息。特别是这些细粒度的重要信息,终于有了和对侧信息交互的机会。
抄近路的思路确定了,那么抄近路的方式,就五花八门,多种多样了。比如除了原始输入能够抄近路,塔中间的一些信息是不是也能抄一把?比如下图模样(BTW,有谁知道logo的出处吗?如果知道,咱俩除了调参炼丹,就有了另一个共同爱好^_^);既然信息在塔中流动过程中就已经损失了,重要信息没必要等到最后一刻再补充,补充到中间层也会大有帮助,就像马拉松选手的中途补水。
但是,这种抄近路的方式,也有其固有的缺点,就是会导致输入层的肿胀。比如原来tower final embedding是64维,你现在要将一些重要的、细粒度的信息也抄近路到最后一层。既然称这些信息是细粒度的,自然是未经过压缩提炼的,维度一般都很大,比如1024维。如果你将原来tower embedding与抄近路信息简单拼接,那么final embedding就会膨胀好几倍,会给线上存储、内存都带来巨大的压力。当然你可以将抄近路的信息,经过一层简单的线性映射,压缩到一个比较小的维度,但这也会引入额外的映射权重,严重时会导致训练时OOM。
所以,将所有原始信息无脑地抄近路,显然是行不通的。这就牵扯到另外一个问题:哪些信息值得抄近路?要回答这个问题,你当然可以跑一个SE block或其他什么算法,获得各特征的重要性。而从我的个人经验来看,我们要特别注意那些“极其个性化”(e.g., userId, itemId)的特征,和,对划分人群、物群有显著区分性的特征(e.g., 用户是新用户还是老用户?用户是否登陆?文章所使用的语言,等)。
条条大路通塔顶
这一思路有两个出发点:
第一个出发点与SENet是相同的:原始双塔,将所有信息一古脑地塞入一个塔,造成向上流动的信息通道拥挤不堪,各路信息相互干扰。 SENet的解决方法是“堵”,在喂入塔之前,就将噪声弱化甚至屏蔽掉,使塔内的信道变得宽敞,保证重要信息无损通过。 而另外一种思路是“疏”,大家没必要都挤一个塔向上流动,不同的信息(甚至是相同的信息)可以沿适合自己的塔向上流动浓缩,避免相互干扰。最后由每个小塔的embedding聚合成final embedding,与对侧的final embedding做dot或cosine。 另一个出发点,与“抄近路”的思路是类似的:我们不再相信(或者说,迷信)DNN的拟合能力。 传统双塔,只有一种信息上升通道,就是DNN。我估计很多同学有与我类似的经历,就是刚接触DNN的时候,听过这样一句话,“只要DNN足够复杂,能够模拟任意函数”。现在看来,这句话的可信性要大打折扣了,Google DCN的论文里宣称,"People generally consider DNNs as universal function approximators, that could potentially learn all kinds of feature interactions. However, recent studies found that DNNs are inefficient to even approximately model 2nd or 3rd-order feature crosses." 既然如此,我们也就没必要将宝都押在DNN这一种通道上。即使是相同的信息,也可以沿多种信息通道向上流动,最终将各通道得到的embedding聚合成final embedding,与对侧交互。
这种思路的典型代表,就是腾讯的并联双塔,“通过并联多个双塔结构增加双塔模型的宽度,来缓解双塔内积的瓶颈从而提升效果”
信息沿着MLP, DCN, FM, CIN这4种通道向上流动。每种通道各有所长,比如MLP是implicit feature cross,FM和DNN都属于explicit and bounded-degree feature cross,大家相互取长补短。 最终各通道的融合,这里是直接拼接,只不过各通道的embedding乘上一个可学习的系数,以形成一个logistic regression的效果。比如我们只有MLP和DCN两个通道,, ,则两侧点积时, = =
另外,涉及到并联双塔训练细节的是,
由于FM和DCN等结构,只能完成信息交叉,而无法信息压缩,所以只能喂入有限的重要特征,否则会引发维度膨胀。 虽然不同结构可能共享特征,但是它们却不共享这些特征的底层embedding。同一个特征,如果要同时喂入MLP和DCN,就必须定义两套embedding,供MLP和DCN分别加以训练。根据我之前的经验,分离embedding空间的确能够换来性能上的提升,但是也带来模型膨胀,给线上serving带来压力。
这种“多塔”思路的另一个代表,来自Facebook的今年最新论文《Que2Search: Fast and Accurate Query and Document Understanding for Search at Facebook》。这篇文章可以算是Facebook 2020年经典论文《Embedding-based Retrieval in Facebook Search》的后继。对比两篇论文的结构图,可以清晰看到“拆一个大塔为若干小塔”的思路变化。在Que2Search中,不同信息通过不同通道向上传递,比如country这样的categorical特征直接embedding,而文本信息则通过XLM。不同通道得到各自的embedding,再融合(fusion)生成final embedding,与对侧塔得到的final embedding计算cosine similarity。
而在融合多塔embedding生成final embedding时,Que2Search也提出Simple Attention Fusion方案,并通过实验证明,比传统的concatenation+mlp方案有效。Simple Attention Fusion的方案如下图所示,其中表示第i个通道得到的embedding, 'f'是各通道融合后的final embedding。
这种"多塔各自embedding + Attention Fusion"的方案,在淘宝的《Embedding-based Product Retrieval in Taobao Search》也有所体现。以user embedding ""为例
是由用户输入的query生成的embedding , , 分别代表由用户实时、短期、长期活动历史生成的embedding 最终的embedding实际上是由, , , 这4方面信息Self-Attention得到。(由于self-attention会得到一个embedding sequence,而不是一个embedding,因此文章作者增加一个dummy token CLS,并拿CLS embedding作为代表整个序列的最终输出,也算是transformer的传统trick了)
对面的塔儿看过来
之前的几种双塔改建方案,都是针对传统双塔“交叉太晚”这一缺点,目标是让更多有效信息“幸存”到final embedding里,“撑”到final dot product那一时刻,并通过净化输入、重要信息走捷径、拓宽信息上升通道等手段来实现这一思路。以上都是“亡羊补牢”的作法,也就是承认双塔就是交叉太晚,然后想方设法减轻这一缺点带来的信息损失。而美团在2021 KDD上发表的最新论文《A Dual Augmented Two-tower Model for Online Large-scale Recommendation》则选择和“交叉太晚”这一难题正面硬刚。我没有复现并尝试美团的思路,不过,我的确觉得这是一个非常有意思的想法,值得借鉴。
美团的“对偶增强双塔”的出发点是:
传统双塔不是交叉太晚吗?那我就发挥深度学习“无中生有”的优点,在user侧“造”出一个embedding模拟item tower的输出,并作为特征接入user tower的最底层; 同理,在item侧“造”出一个embedding模拟user tower的输出,并作为特征接入item tower的最底层; 这样一来,相当于item tower的输出成为了user tower的输入,user tower的输出成了item tower的输入,两侧信息从一开始就发生了交叉,交叉大大提前了。
具体实现简单描述如下。要了解详情的同学,请移步美团的原文
user塔的最底层输入: , 前边几项比较常规,代表是userId(e.g., 253)、地域(e.g., 上海)、性别(e.g., 男)等特征的embedding的拼接。 最后的就是user侧增强向量,通过user id在一个embedding matrix中查询得到,代表了来自item tower侧与该user进行过正向交互的所有items的信息 item塔的最底层输入: 前边几项比较常规,代表是itemId(e.g., 149)、价格(e.g., 10元)、类别(e.g., cate)等特征的embedding的拼接。 最后的就是item侧的增强向量,通过itemId在一个embedding matrix中查询得到,代表了来自user tower侧与该item进行过正向交互的所有users的信息 由原始输入/生成final embedding /的过程就比较传统了,就是分别喂入两侧的MLP “对偶增强双塔”的关键是如何学习好两侧的增强向量和,为此作者设计了Adaptive Mimic Mechanism辅助loss来专门训练它们。如下图所示,公式只在y=1时才发挥作用,目标是让,与所有与之正向交互过的item的final embedding 尽可能接近,反之对有类似作用。
通过以上方式,“对偶增强双塔”中的每个塔,学习到了对侧信息,并将其作为本侧塔的底层输入,从而使双塔之间从底层就发生交叉。
其实,美团的这种方式,有点蒸馏学习的味道。另外,要将“与某用户/物料交互过的物料/用户信息作为特征”,恐怕也不必这么麻烦,只要将user action list中的item id先embedding再pooling就可以。但是,真正有意义的是美团这种"让双塔相互深情对视"的创新思路,对我有非常强的借鉴意义,未来或许用得上。
再见,双塔?
如前所述,双塔模型的最大缺点就在于由于双塔分离,造成不能使用任何交叉特征和交叉架构,使得模型表达能力大打折扣。既然召回、粗排“苦双塔久矣”,何不鼓起勇气,大声和双塔说再见?
又是阿里妈妈一马当先,引领了这一波“去双塔”的潮流。阿里妈妈先是用TDM让召回告别了双塔,又用COLD+FSCD拉开了粗排向双塔告别的序幕。阿里的粗排算法,简单概括之,就是“粗排=特征筛选+精排”。而且遵循两步走策略:先训练一个模型,通过正则,筛选出“表达能力+计算耗时”性价比高的特征;再用筛选出的特征,训练出一个能够使用交叉特征+交叉结构的模型,用于粗排。以上只是我对原理的概括,欲了解详情,请移步阿里论文的原文。
那是不是说,现在我们就可以和“双塔”说bye bye了?正如歌中所唱的那样,“现在说再见还为时太早”。
首先,COLD+FSCD那样的去双塔模型,训练起来好训。但是部署上线的时候,要让每对儿<user, item>都过一遍模型,想想粗排那远超精排的候选集规模,其中涉及的线上serving的工程优化,恐怕非朝夕之功。 其次,双塔模型能够快速得到user & item embedding,即便不用于召回或粗排,也能够为精排提供特征。 最后,当有充足的线上资源,我们甚至可以改变目前的“召回→粗排→精排”三层架构。我们可以将“粗排”一分为二,用“双塔”作为“召回层的粗排”,用一个“能使用交叉特征+交叉结构的简化版精排”作为“排序层的粗排”,目标是筛选出更优质的物料喂入精排。
总结
本文是我和双塔模型死磕了6个月之后的心得体会。如前文所述,双塔分离,既是保障线上快速serving、以适应召回+粗排场景的优点,也是不能使用交叉特征与结构、导致两侧信息交叉过晚、制约模型表达能力的最大缺点。user&item两侧信息交叉得太晚,等到最终能够通过dot或cosine交叉的时候,user & item embedding已经高度浓缩,一些细粒度的信息已经在塔中被损耗,永远失去了与对侧信息交叉的机会。
为了克服这一缺点,业界同仁设计出许多改进方案。这些方案背后有一个共同的思路,就是减少信息沿塔上升过程中的损耗,让更多细粒度的重要信息能够“幸存”到final embedding中,能够“撑”到final dot product那一刻。然后我分析了Facebook、阿里、腾讯、新浪、美团等中外大厂的工作,看看我的中外同仁们如何从“净化输入”、“重要信息走捷径”、“拓宽信息上升通道”、“双塔相互模仿”等方面实现了这个思路,克服双塔的缺点,提升其性能。
正如我之前经常论述的“道”与“术”,“了解双塔的缺点,知道从哪里改进”是“道”,“怎么改进”是“术”,深刻理解“道”之后,才能将各大厂的“术”综合运用,甚至创造你自己的“术”,提升你自己模型的性能。比如,如果美团方案中的对偶增强向量真的那么重要,那么不将它们接入DNN最底层,而是直接抄近路到塔的最后一层,离final dot-product更近,是不是效果更佳?至于是否是这样,就要等GPU和AB平台告诉我们了。
最后,我也指出,尽管阿里妈妈已经在召回+粗排领域告别了“双塔”,但是现在和双塔彻底告别还为时过早,“双塔”模型仍然是我们广大算法同仁手中得心应手的一件兵器,你值得拥有。
回归之后的第一篇就这样写完了,啥时候写下一篇呢?写点啥呢?对了,问一下知乎的运营同学,我的“返航专属礼盒”怎么还没寄到?
- END -